# Copyright (c) 2025 Apple Inc. Licensed under MIT License.

"""The embedding atlas widget for notebooks"""

import pathlib
from typing import Any, Unpack

import duckdb

from .options import EmbeddingAtlasOptions, make_embedding_atlas_props
from .utils import arrow_to_bytes

# The widget class below subclasses anywidget.AnyWidget and declares traitlets
# traits. If either import fails, print an install hint for the user and then
# re-raise the original ImportError so the failure is not hidden.
try:
    import anywidget
    import traitlets
except ImportError:
    print(
        "⚠️ The widget depends on anywidget. Please run `pip install anywidget`, then try again."
    )
    raise


class EmbeddingAtlasWidget(anywidget.AnyWidget):
    """An Embedding Atlas widget in notebooks.

    The input data is copied into a temporary DuckDB table. SQL requests sent
    by the front end are executed against that table by
    :meth:`_handle_custom_msg`, and :meth:`selection` returns the rows matching
    the current selection predicate.
    """

    _esm = pathlib.Path(__file__).parent / "widget_static" / "anywidget" / "index.js"

    # The props to the embedding atlas component, internal use only
    _props = traitlets.Dict({}).tag(sync=True)

    # The state of the embedding atlas component, internal use only
    _state = traitlets.Any(None).tag(sync=True)

    # Synced trait read by `selection()`, where it is used as the SQL WHERE
    # clause; None means no predicate (all rows). Internal use only.
    _predicate = traitlets.Any(None).tag(sync=True)

    def __init__(
        self,
        data_frame: Any,
        *,
        connection: duckdb.DuckDBPyConnection | None = None,
        **options: Unpack[EmbeddingAtlasOptions],
    ):
        """
        Create an Embedding Atlas widget.

        Args:
            data_frame:
                A DataFrame/Arrow object to "register" with DuckDB.

            row_id:
                The column name for row id (if not specified, or ``None``, a row id
                column named ``__row_id__`` will be added).

            x:
                The column name for X axis in the embedding.

            y:
                The column name for Y axis in the embedding.

            text:
                The column name for the textual data.

            neighbors:
                The column name containing precomputed K-nearest neighbors for each point.
                Each value in the column should be a dictionary with the format:
                ``{ "ids": [id1, id2, ...], "distances": [distance1, distance2, ...] }``.

                - ``"ids"`` should be an array of row ids of the neighbors
                  (if ``row_id`` is specified, match the value in row_id, otherwise use the row index),
                  sorted by distance.
                - ``"distances"`` should contain the corresponding distances to each neighbor.

            labels:
                Labels for the embedding view. Set to string ``"automatic"`` to generate labels automatically, or ``"disabled"`` to disable auto labels.
                Automatic labels are generated by clustering the 2D density distribution and selecting
                representative keywords using TF-IDF ranking.
                You can also pass in a list of labels. Each label must contain ``x`` and ``y`` coordinates
                and ``text`` for the label content. Optionally, you may specify an integer ``level`` to roughly
                control the zoom level where the label appears, and ``priority`` for the label's priority.
                Higher priority labels have a better chance to appear when multiple labels overlap.

            stop_words:
                Stop words for automatic label generation.

            point_size:
                Override the default point size for the embedding view.

            show_table:
                Whether to display the data table when the widget opens.

            show_charts:
                Whether to display charts when the widget opens.

            show_embedding:
                Whether to display the embedding view when the widget opens.

            connection (DuckDBPyConnection, optional):
                A DuckDB connection. Defaults to duckdb.connect().
        """

        # The SQL below refers to `data_frame` by name and DuckDB resolves it
        # from this frame's local variables; keep the name bound (and visibly
        # "used") so the reference is not removed as dead code.
        _ = data_frame  # used by DuckDB

        table_name = "embedding_atlas"

        # Treat an explicit `row_id=None` the same as an omitted `row_id`, so
        # the column name and the "add a row id column" decision below can
        # never disagree.
        user_row_id = options.get("row_id")
        needs_row_id_column = user_row_id is None
        row_id_column = "__row_id__" if needs_row_id_column else user_row_id

        props = make_embedding_atlas_props(
            **(options | {"table": table_name, "row_id": row_id_column}),
        )

        if connection is None:
            connection = duckdb.connect()

        connection.sql(
            f"CREATE TEMPORARY TABLE {table_name} AS SELECT * FROM data_frame"
        )

        if needs_row_id_column:
            # Create the row_id_column if it does not exist, filling it from a
            # throwaway sequence that is dropped again afterwards.
            # NOTE(review): the sequence is created with DuckDB's defaults, so
            # ids presumably start at 1 — confirm this matches the "row index"
            # convention documented for `neighbors`.
            connection.sql(
                f"""
                ALTER TABLE {table_name} ADD COLUMN {row_id_column} INTEGER;
                CREATE TEMPORARY SEQUENCE row_id_sequence;
                UPDATE {table_name} SET {row_id_column} = nextval('row_id_sequence');
                DROP SEQUENCE row_id_sequence;
                """
            )

        super().__init__()

        self._props = props

        self._connection: duckdb.DuckDBPyConnection = connection
        self._table_name = table_name
        # NOTE(review): `_handle_custom_msg` presumably also overrides the
        # ipywidgets hook of the same name — confirm whether this explicit
        # registration is still required.
        self.on_msg(self._handle_custom_msg)

    def selection(self, format: str = "dataframe") -> Any:
        """
        Returns the current selection in the widget.

        Args:
            format: the format of the returned selection, 'dataframe', 'arrow', or 'predicate'

        Returns:
            For 'dataframe' and 'arrow', the rows of the widget's table that match
            the current predicate (all rows when there is no predicate), fetched via
            ``fetch_df()`` or ``fetch_arrow_table()`` respectively. For 'predicate',
            the current predicate value itself, or ``None`` when there is none.

        Raises:
            ValueError: if ``format`` is not one of the supported options.
        """
        # Validate before touching the database so that an invalid format does
        # not trigger a (potentially expensive) full-table query first.
        if format not in ("dataframe", "arrow", "predicate"):
            raise ValueError(
                "invalid format, supported options are 'dataframe', 'arrow', and 'predicate'"
            )

        predicate = self._predicate
        if format == "predicate":
            return predicate

        query = f"SELECT * FROM {self._table_name}"
        if predicate is not None:
            query = f"{query} WHERE {predicate}"
        self._connection.execute(query)

        if format == "dataframe":
            return self._connection.fetch_df()
        return self._connection.fetch_arrow_table()

    def _handle_custom_msg(self, content: dict, buffers: list):
        """
        Execute a SQL request from the front end and send back the result.

        The message is expected to carry ``uuid`` (echoed back in the reply),
        ``sql`` (the statement to run) and ``type``, one of:

        - ``"arrow"``: run the query and reply with the result converted by
          ``arrow_to_bytes`` as a binary buffer.
        - ``"exec"``: execute the statement and reply with an acknowledgement.
        - ``"json"``: run the query and reply with the rows as a list of records.

        Any failure, including a malformed message or an unknown ``type``, is
        reported with an ``{"error": ..., "uuid": ...}`` reply rather than raised.

        Args:
            content: the message payload.
            buffers: binary buffers attached to the message (unused).
        """
        # Read the uuid leniently so that an error reply can still be sent even
        # when the message is missing required keys.
        uuid = content.get("uuid")

        try:
            sql = content["sql"]
            command = content["type"]

            if command == "arrow":
                result = self._connection.query(sql).arrow()
                buf = arrow_to_bytes(result)
                self.send({"type": "arrow", "uuid": uuid}, buffers=[buf])
            elif command == "exec":
                self._connection.execute(sql)
                self.send({"type": "exec", "uuid": uuid})
            elif command == "json":
                result = self._connection.query(sql).df()
                records = result.to_dict(orient="records")
                self.send({"type": "json", "uuid": uuid, "result": records})
            else:
                raise ValueError(f"Unknown command {command}")
        except Exception as e:
            # Top-level message boundary: report every failure to the requester.
            self.send({"error": str(e), "uuid": uuid})
